{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#  Evaluating with CMMD using EvaluationAgent"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/evaluation_agent_cmmd.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial demonstrates how to use the `pruna` package to evaluate a model. We will use the `sdxl-turbo` model and a subset of the `LAION256` dataset as an example. Any execution times given below are measured on a T4 GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you are not running the latest version of this tutorial, make sure to install the matching version of pruna\n",
    "# the following command will install the latest version of pruna\n",
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Loading the Stable Diffusion Model\n",
    "\n",
    "First, load your model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import AutoPipelineForText2Image\n",
    "\n",
    "from pruna.engine.pruna_model import PrunaModel\n",
    "\n",
    "pipe = AutoPipelineForText2Image.from_pretrained(\"stabilityai/sdxl-turbo\")\n",
    "model = PrunaModel(pipe)\n",
    "pipe.set_progress_bar_config(disable=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  2. Create Metrics \n",
    "\n",
    "`pruna` allows you to pass your metrics requests in 3 ways: \n",
    "\n",
    "1. As a plain text request from predefined options (e.g., `image_generation_quality`)\n",
    "\n",
    "2. As a list of metric names \n",
    "\n",
    "3. As a list of metric instances\n",
    "\n",
    "Options 1 and 2 uses the default settings for each metric. For full control over the metric class use option 3.\n",
    "\n",
    "The default `call_type` for `cmmd` is `single`. This means that the metric will produce a score for each model. To create one comparison score between two models, set `call_type` to `pairwise`.\n",
    "\n",
    "To learn more about `single` and `pairwise`, please refer to `pruna` [documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html#metric-call-types).\n",
    "\n",
    "In this example we will use `cmmd` as our evaluation metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Option 1: Using a simple string (default = single mode) ---\n",
    "# request = \"image_generation_quality\"\n",
    "\n",
    "\n",
    "# --- Option 2: Using a simple string (default = single mode) ---\n",
    "request = [\"cmmd\"]\n",
    "\n",
    "# --- Option 3: Full control using the class ---\n",
    "# from pruna.evaluation.metrics import CMMD\n",
    "# request = [CMMD()]  # For single mode\n",
    "# request = [CMMD(call_type=\"pairwise\")]  # For pairwise mode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Create an EvaluationAgent and a Task with metrics request\n",
    "\n",
    "Pruna's evaluation process uses a Task to define which metrics to calculate and provide the evaluation data. The EvaluationAgent then takes this Task and handles running the model inference, passing the inputs, ground truth, and predictions to each metric, and collecting the results.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna.data.pruna_datamodule import PrunaDataModule\n",
    "from pruna.evaluation.evaluation_agent import EvaluationAgent\n",
    "from pruna.evaluation.task import Task\n",
    "\n",
    "datamodule = PrunaDataModule.from_string(\"LAION256\")\n",
    "# If you would like to limit the number of samples to evaluate, uncomment the following line\n",
    "# datamodule.limit_datasets(10)\n",
    "task = Task(request, datamodule)\n",
    "eval_agent = EvaluationAgent(task)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Evaluate the first model\n",
    "\n",
    "We can evaluate the first model even before smashing.\n",
    "\n",
    "This is done by calling the `evaluate` method of the EvaluationAgent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: tweak model generation parameters for benchmarking\n",
    "model.inference_handler.model_args.update({\"num_inference_steps\": 1, \"guidance_scale\": 0.0})\n",
    "\n",
    "base_results = eval_agent.evaluate(model)\n",
    "print(base_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Smash the model\n",
    "\n",
    "Smash the model as usual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "from pruna import smash\n",
    "from pruna.config.smash_config import SmashConfig\n",
    "from pruna.engine.utils import safe_memory_cleanup\n",
    "\n",
    "smash_config = SmashConfig([\"deepcache\"])\n",
    "\n",
    "\n",
    "copy_pipe = copy.deepcopy(pipe)\n",
    "smashed_pipe = smash(copy_pipe, smash_config)\n",
    "smashed_pipe.set_progress_bar_config(disable=True)\n",
    "# Optional: tweak model generation parameters for benchmarking\n",
    "smashed_pipe.inference_handler.model_args.update({\"num_inference_steps\": 1, \"guidance_scale\": 0.0})\n",
    "safe_memory_cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Evaluate the subsequent model\n",
    "\n",
    "`EvaluationAgent` allows you to compare any kind of models. You can compare a baseline model with a smashed model, or two smashed models, or even two baseline models.\n",
    "\n",
    "In this example, we now evaluate the smashed model. This is done by again calling the `evaluate` method of the EvaluationAgent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "smashed_results = eval_agent.evaluate(smashed_pipe)\n",
    "print(smashed_results)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pruna",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
